import math
from typing import Any

from django.core.exceptions import ValidationError

from .base import CVSS_REGEX

# Scoring weights for the CVSS v3.x metric groups. Each table maps a metric
# abbreviation (as written in a vector string) to its allowed value letters and
# the weight each value contributes to the score formulas.

# Confidentiality / Integrity / Availability impact weights.
_CIA_IMPACT_WEIGHTS = {'N': 0, 'L': 0.22, 'H': 0.56}

# Base metric group. Privileges Required ('PR') is keyed first by its own value
# and then by the Scope value ('U' or 'C'), because its weight depends on scope.
# Scope ('S') maps to itself: the score code compares it to 'C' and uses it to
# index 'PR' rather than multiplying by it.
CVSS3_METRICS_BASE = {
    'AV': {'N': 0.85, 'A': 0.62, 'L': 0.55, 'P': 0.2},
    'AC': {'L': 0.77, 'H': 0.44},
    'PR': {
        'N': {'U': 0.85, 'C': 0.85},
        'L': {'U': 0.62, 'C': 0.68},
        'H': {'U': 0.27, 'C': 0.5},
    },
    'UI': {'N': 0.85, 'R': 0.62},
    'S': {'U': 'U', 'C': 'C'},
    'C': dict(_CIA_IMPACT_WEIGHTS),
    'I': dict(_CIA_IMPACT_WEIGHTS),
    'A': dict(_CIA_IMPACT_WEIGHTS),
}

# Temporal metric group. 'X' (Not Defined) weighs 1, leaving the score unchanged.
CVSS3_METRICS_TEMPORAL = {
    'E': {'X': 1, 'H': 1, 'F': 0.97, 'P': 0.94, 'U': 0.91},
    'RL': {'X': 1, 'U': 1, 'W': 0.97, 'T': 0.96, 'O': 0.95},
    'RC': {'X': 1, 'C': 1, 'R': 0.96, 'U': 0.92},
}

# Security requirement weights shared by CR, IR and AR.
_REQUIREMENT_WEIGHTS = {'X': 1, 'L': 0.5, 'M': 1, 'H': 1.5}

# Environmental metric group: the three security requirements, followed by a
# modified ('M'-prefixed) variant of every base metric in base-table order.
# A modified metric accepts the same values as its base metric plus 'X'
# (Not Defined), which is stored as None. The nested PR weight dicts are
# shared with CVSS3_METRICS_BASE; all of these tables are treated as read-only.
CVSS3_METRICS_ENVIRONMENTAL = {
    **{name: dict(_REQUIREMENT_WEIGHTS) for name in ('CR', 'IR', 'AR')},
    **{'M' + name: {'X': None, **choices} for name, choices in CVSS3_METRICS_BASE.items()},
}

# Every known metric, in base -> temporal -> environmental order.
CVSS3_METRICS = {**CVSS3_METRICS_BASE, **CVSS3_METRICS_TEMPORAL, **CVSS3_METRICS_ENVIRONMENTAL}


def parse_cvss3(vector, version='3.0'):
    """
    Parse and validate a CVSS 3.0 or CVSS 3.1 vector string.

    For CVSS 3.0 and 3.1 the metrics are the same (only descriptions and definitions
    changed), so both versions are validated against CVSS3_METRICS.

    :param vector: vector string, e.g. 'CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H'
    :param version: expected version in the 'CVSS:<version>' prefix ('3.0' or '3.1')
    :return: dict mapping metric abbreviations to their value letters, e.g. {'AV': 'N', ...}
    :raises ValidationError: if the vector is empty or not a string, does not match
        CVSS_REGEX or the expected version prefix, contains an unknown metric or value,
        defines the same metric more than once, or omits a base metric
    """
    prefix = 'CVSS:' + version
    # The isinstance check turns non-string input into a ValidationError instead of a
    # TypeError from the regex, so is_cvss3_* callers can rely on catching ValidationError.
    if not isinstance(vector, str) or not vector or not CVSS_REGEX.match(vector) \
            or not vector.startswith(prefix):
        raise ValidationError(f'Invalid CVSS:{version} vector: Invalid format')

    # Parse the "KEY:VALUE" components after the version prefix; empty components
    # (e.g. produced by the leading '/') are skipped.
    values = {}
    for component in filter(None, vector[len(prefix):].split('/')):
        # partition() never raises, so a malformed component (no ':' or several ':')
        # is reported as an invalid metric value rather than escaping as a ValueError.
        k, _, v = component.partition(':')
        if k not in CVSS3_METRICS or v not in CVSS3_METRICS[k]:
            raise ValidationError(
                f'Invalid CVSS:{version} vector: invalid metric value "{k}:{v}"')
        # Each metric may only be specified once; silently keeping the last occurrence
        # would hide contradictory vectors.
        if k in values:
            raise ValidationError(
                f'Invalid CVSS:{version} vector: metric "{k}" defined more than once')
        values[k] = v

    # Validate required metrics
    for m in CVSS3_METRICS_BASE:
        if m not in values:
            raise ValidationError(
                f'Invalid CVSS:{version} vector: base metric "{m}" missing')

    return values


def is_cvss3_0(vector):
    """
    Return True if vector is accepted by parse_cvss3 as a CVSS:3.0 vector, else False.

    Only ValidationError is treated as "not a CVSS 3.0 vector"; any other exception
    raised by parse_cvss3 propagates.
    """
    try:
        parse_cvss3(vector, version='3.0')
    except ValidationError:
        return False
    return True


def is_cvss3_1(vector):
    """
    Return True if vector is accepted by parse_cvss3 as a CVSS:3.1 vector, else False.

    Only ValidationError is treated as "not a CVSS 3.1 vector"; any other exception
    raised by parse_cvss3 propagates.
    """
    try:
        parse_cvss3(vector, version='3.1')
    except ValidationError:
        return False
    return True


def round_up(input):
    """
    Round a number up to one decimal place (the CVSS 3.1 "Roundup" function).

    The input is first scaled by 100000 and rounded to an integer. This absorbs
    floating-point noise, so a value like 4.000000000000001 stays 4.0 instead of
    becoming 4.1. A value already on a tenth boundary is returned as-is; anything
    else is bumped up to the next tenth.

    :param input: number to round (the name shadows the builtin but is kept for callers)
    :return: float rounded up to one decimal place
    """
    scaled = round(input * 100000)
    if scaled % 10000:
        # Not on a tenth boundary: move up to the next tenth.
        return (math.floor(scaled / 10000) + 1) / 10.0
    return scaled / 100000.0


def calculate_score_cvss3_1(vector) -> dict | None:
    """
    Compute the CVSS 3.1 base, temporal and environmental scores of a vector string.

    Returns None when parse_cvss3 rejects the vector as CVSS:3.1 (ValidationError).
    Otherwise returns a dict with the keys "version" ("3.1"), "environmental",
    "base", "temporal" and "final". Each score entry is a dict with "score"
    (rounded via round_up) and the unrounded "exploitability" and "impact"
    sub-scores. "final" repeats the environmental entry when any environmental
    metric is defined (not X). Otherwise it repeats the temporal entry, whose score
    equals the base score when no temporal metric is defined.
    """
    try:
        values = parse_cvss3(vector, version='3.1')
    except ValidationError:
        return None

    def defines_any(group):
        # True if at least one metric of the group is present with a value other than X.
        for name in group:
            if name in values and values[name] != 'X':
                return True
        return False

    def weight(name, modified=False) -> Any:
        # Lookup order: the modified metric (M<name>) when requested and defined,
        # then the metric itself, then the metric's X (Not Defined) entry.
        table = CVSS3_METRICS.get(name, {})
        if modified:
            override = CVSS3_METRICS.get('M' + name, {}).get(values.get('M' + name))
            if override is not None and override != 'X':
                return override
        found = table.get(values.get(name))
        return table.get('X') if found is None else found

    def with_temporal(score):
        # Multiply by E, RL and RC left to right (same float evaluation order), then round up.
        return round_up(score * weight('E') * weight('RL') * weight('RC'))

    def entry(score, exploitability, impact):
        # A fresh score entry dict each time, so no two result keys share an object.
        return {"score": score, "exploitability": exploitability, "impact": impact}

    # Environmental score: modified metrics (falling back to base) weighted by CR/IR/AR.
    env_scope_changed = weight('S', modified=True) == 'C'
    miss = min(1 - (
        (1 - weight('C', modified=True) * weight('CR')) *
        (1 - weight('I', modified=True) * weight('IR')) *
        (1 - weight('A', modified=True) * weight('AR'))
    ), 0.915)
    if env_scope_changed:
        env_impact = 7.52 * (miss - 0.029) - 3.25 * pow(miss * 0.9731 - 0.02, 13)
    else:
        env_impact = 6.42 * miss
    env_exploitability = (
        8.22
        * weight('AV', modified=True)
        * weight('AC', modified=True)
        * weight('PR', modified=True)[weight('S', modified=True)]
        * weight('UI', modified=True)
    )
    if env_impact <= 0:
        env_score = 0.0
    else:
        env_total = env_impact + env_exploitability
        if env_scope_changed:
            env_total = 1.08 * env_total
        env_score = with_temporal(round_up(min(env_total, 10)))

    # Base score: unmodified metrics only, scope from S.
    scope_changed = weight('S') == 'C'
    iss = 1 - ((1 - weight('C')) * (1 - weight('I')) * (1 - weight('A')))
    if scope_changed:
        impact = 7.52 * (iss - 0.029) - 3.25 * pow(iss - 0.02, 15)
    else:
        impact = 6.42 * iss
    exploitability = 8.22 * weight('AV') * weight('AC') * weight('PR')[weight('S')] * weight('UI')
    if impact <= 0:
        base_score = 0.0
    else:
        total = impact + exploitability
        if scope_changed:
            total = 1.08 * total
        base_score = round_up(min(total, 10))

    # Temporal score: only re-rounded when a temporal metric is actually defined.
    temporal_score = with_temporal(base_score) if defines_any(CVSS3_METRICS_TEMPORAL) else base_score

    if defines_any(CVSS3_METRICS_ENVIRONMENTAL):
        final = entry(env_score, env_exploitability, env_impact)
    else:
        final = entry(temporal_score, exploitability, impact)

    result: dict[str, str | dict] = {
        "version": "3.1",
        "environmental": entry(env_score, env_exploitability, env_impact),
        "base": entry(base_score, exploitability, impact),
        "temporal": entry(temporal_score, exploitability, impact),
        "final": final,
    }
    return result


def calculate_score_cvss3_0(vector) -> dict | None:
    """
    Compute the CVSS 3.0 base, temporal and environmental scores of a vector string.

    Returns None when parse_cvss3 rejects the vector as CVSS:3.0 (ValidationError).
    Otherwise returns a dict with the keys "version" ("3.0"), "base", "temporal",
    "environmental" and "final". Each score entry is a dict with "score" (rounded
    via round_up) and the unrounded "exploitability" and "impact" sub-scores.
    "final" is the same dict object as:
    - the environmental entry, when any environmental metric is defined (not X);
    - otherwise the temporal entry, when any temporal metric is defined;
    - otherwise the base entry.
    """
    try:
        values = parse_cvss3(vector, version='3.0')
    except ValidationError:
        return None

    def has_metric_group(group):
        # True if any metric of the group is present with a value other than X (Not Defined).
        return any(m in values and values[m] != 'X' for m in group)

    def metric(name, modified=False) -> Any:
        # First try the modified metric (M<name>), then the original metric, then X (Not Defined).
        if modified:
            m = CVSS3_METRICS.get('M' + name, {}).get(values.get('M' + name))
            if m is not None and m != 'X':
                return m
        m = CVSS3_METRICS.get(name, {}).get(values.get(name))
        if m is not None:
            return m
        return CVSS3_METRICS.get(name, {}).get('X')

    # Base score. It must use the unmodified Scope (S). Previously the Modified Scope (MS)
    # was used here, so e.g. "S:U/.../MS:C" scored the base metrics with the
    # scope-changed formula.
    scope_changed = metric('S') == 'C'
    isc_base = 1 - (
        (1 - metric('C')) *
        (1 - metric('I')) *
        (1 - metric('A')))
    exploitability_base = 8.22 * metric('AV') * metric('AC') * metric(
        'PR')[metric('S')] * metric('UI')
    impact_base = 7.52 * (isc_base - 0.029) - 3.25 * \
        pow((isc_base - 0.02), 15) if scope_changed else 6.42 * isc_base
    score_base = 0.0 if impact_base <= 0 else (
        round_up(min(1.08 * (impact_base + exploitability_base), 10)) if scope_changed else
        round_up(min(impact_base + exploitability_base, 10))
    )
    score_temporal = round_up(score_base * metric('E')
                              * metric('RL') * metric('RC'))

    # Environmental score. Uses the modified base metrics (falling back to the base values)
    # weighted by the security requirements CR/IR/AR. Scope comes from MS, falling back to S.
    m_scope_changed = metric('S', modified=True) == 'C'
    isc_modified = min(1 - (
        (1 - metric('C', modified=True) * metric('CR')) *
        (1 - metric('I', modified=True) * metric('IR')) *
        (1 - metric('A', modified=True) * metric('AR'))
    ), 0.915)
    impact_modified = 7.52 * (isc_modified - 0.029) - 3.25 * pow(isc_modified - 0.02, 15) \
        if m_scope_changed else 6.42 * isc_modified
    exploitability_modified = 8.22 * metric('AV', modified=True) * metric('AC', modified=True) * metric(
        'PR', modified=True)[metric('S', modified=True)] * metric('UI', modified=True)
    score_environmental = 0.0 if impact_modified <= 0 else (
        round_up(min(1.08 * (impact_modified + exploitability_modified), 10)) if m_scope_changed else
        round_up(min(impact_modified + exploitability_modified, 10))
    )
    # Apply the temporal multipliers on top of the modified base score.
    score_environmental = round_up(
        score_environmental * metric('E') * metric('RL') * metric('RC'))
    result = {
        "version": "3.0",
        "base": {
            "score": score_base,
            "exploitability": exploitability_base,
            "impact": impact_base,
        },
        "temporal": {
            "score": score_temporal,
            "exploitability": exploitability_base,
            "impact": impact_base,
        },
        "environmental": {
            "score": score_environmental,
            "exploitability": exploitability_modified,
            "impact": impact_modified,
        },
    }
    if has_metric_group(CVSS3_METRICS_ENVIRONMENTAL):
        result["final"] = result["environmental"]
    elif has_metric_group(CVSS3_METRICS_TEMPORAL):
        result["final"] = result["temporal"]
    else:
        result["final"] = result["base"]

    return result
